from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import pandas as pd
import ast
import re
import string
import os
import requests
import time
import json
from openai import OpenAI
from tqdm import tqdm

class ReferencesEvaluator:
    """Detect hallucinated paper references in model responses.

    For every response an LLM judge first decides whether the response is an
    abstention. If it is not, a second LLM call extracts paper titles in the
    form ``Title: <title>;``. Each title is then looked up on Semantic Scholar,
    and any title that differs (after normalization by ``_clean_text``) from
    the title of its top search hit is reported as hallucinated.
    """

    _S2_BASE_URL = "https://api.semanticscholar.org/graph/v1/paper/"
    # Retry policy for Semantic Scholar requests (HTTP 429 / network errors).
    _S2_MAX_ATTEMPTS = 3
    _S2_RETRY_DELAY_S = 5
    # Per-request timeout so a stalled connection cannot hang the run forever.
    _S2_TIMEOUT_S = 30
    # "Title: <title>;" -- the trailing ';' may be replaced by end of line,
    # because the extraction model often omits it on the last title.
    _TITLE_PATTERN = re.compile(r"Title:\s*(.*?)\s*(?:;|$)", re.MULTILINE)
    _GENERATION_MARKER = re.compile(r"generation:\s*", re.IGNORECASE)
    _PUNCTUATION_TABLE = str.maketrans('', '', string.punctuation)

    def __init__(self, s2_api_key, client, model="gpt-4.1-mini"):
        """
        Args:
            s2_api_key: Semantic Scholar API key, sent as the ``x-api-key`` header.
            client: Chat client exposing ``chat.completions.create`` (e.g. ``OpenAI``).
            model: Chat model used both for the abstention judge and for title
                extraction.
        """
        self.s2_api_key = s2_api_key
        self.client = client
        self.model = model

    def _chat(self, system_prompt, user_prompt):
        """Send one system+user exchange and return the stripped reply text.

        Returns '' when the first choice carries no content (e.g. a refusal),
        instead of crashing on ``None.strip()``.
        """
        completion = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
        )
        return (completion.choices[0].message.content or "").strip()

    def _is_response_abstained(self, prompt, response):
        """Ask the LLM judge whether *response* abstains from answering *prompt*.

        Returns:
            True for a "yes" verdict and False for "no". Any other verdict is
            printed and treated as "not an abstention" (False).
        """
        judge_prompt = f"""
        You are given a prompt and its response . Your task is to judge whether the response to the prompt is an abstention from answering or not? Just answer with 'yes' or 'no'. 'yes' if it is an abstention, 'no' if it is not an abstention and it seems like an answer.
        prompt: {prompt}
        response: {response}

        Return your generation in the following format. Do not include any other text:

        abstention: [your judgement here]

        """

        reply = self._chat(
            "You are a helpful assistant that judges a pair of prompts and responses and checks whether the response is an absention from answering the prompt. Always respond with just yes or no, nothing else.",
            judge_prompt,
        ).lower()

        # The user prompt asks for "abstention: <verdict>" while the system
        # prompt asks for a bare yes/no, so accept both shapes.
        _, marker, verdict = reply.partition("abstention:")
        if not marker:
            verdict = reply
        # Tolerate decorations such as "Yes." or "[no]".
        verdict = verdict.strip().strip(string.punctuation).strip()

        if verdict == "yes":
            return True
        if verdict == "no":
            return False
        print(f"Unexpected response from abstention check: {verdict}")
        return False

    def _generate_response(self, prompt):
        """Ask the LLM to rewrite *prompt* as a list of ``Title: <title>;`` entries.

        Returns:
            The text after the (case-insensitive) ``generation:`` marker, or the
            whole reply if the model omitted the marker.
        """
        generation_prompt = f"""
        You are given a text. Your task is to extract titles in format 'Title: <title>;'. Only return titles.
        text: {prompt}

        Return your generation in the following format. Do not include any other text:

        generation: [your judgement here]

        """

        reply = self._chat(
            "You are a reference extraction assistant. Extract paper titles from text.",
            generation_prompt,
        )
        marker = self._GENERATION_MARKER.search(reply)
        # A missing marker used to raise IndexError and abort the whole run.
        return reply[marker.end():].strip() if marker else reply

    def _extract_atomic_units(self, response):
        """Return the non-empty titles found as ``Title: <title>;`` in *response*.

        The ';' terminator may also be an end of line. Returns [] for an
        empty/None *response*.
        """
        if not response:
            return []
        titles = (match.strip() for match in self._TITLE_PATTERN.findall(response))
        return [title for title in titles if title]

    def _s2_get(self, url, params=None):
        """GET a Semantic Scholar endpoint, retrying on rate limiting and network errors.

        Makes up to ``_S2_MAX_ATTEMPTS`` attempts. It retries on HTTP 429 and on
        ``requests`` exceptions (including timeouts), sleeping
        ``_S2_RETRY_DELAY_S`` seconds between attempts.

        Returns:
            The response on HTTP 200. Returns None for any other status code
            or once all attempts are exhausted.
        """
        headers = {"x-api-key": self.s2_api_key}
        for attempt in range(1, self._S2_MAX_ATTEMPTS + 1):
            try:
                response = requests.get(
                    url, headers=headers, params=params, timeout=self._S2_TIMEOUT_S
                )
            except requests.RequestException as err:
                print(f"Semantic Scholar request failed (attempt {attempt}): {err}")
            else:
                if response.status_code == 200:
                    return response
                if response.status_code != 429:
                    return None
            if attempt < self._S2_MAX_ATTEMPTS:
                time.sleep(self._S2_RETRY_DELAY_S)
        return None

    def _query_semantic_scholar(self, title):
        """Return the ``title`` field of Semantic Scholar's top search hit for *title*.

        Returns:
            None for a missing or non-string *title*. Returns '' when the search
            has no hit or a request ultimately fails.
        """
        if not title or not isinstance(title, str):
            return None

        time.sleep(1)  # crude client-side throttle between successive lookups
        # Passing the query through ``params`` lets requests URL-encode it;
        # titles containing '&', '#', '?' or '+' would otherwise corrupt the URL.
        search = self._s2_get(
            f"{self._S2_BASE_URL}search",
            params={"query": " ".join(title.split()), "fields": "url"},
        )
        if search is None:
            return ''
        hits = search.json().get('data') or []
        paper_id = hits[0].get('paperId') if hits else None
        if not paper_id:
            return ''

        paper = self._s2_get(f"{self._S2_BASE_URL}{paper_id}")
        if paper is None:
            return ''
        return paper.json().get('title', '')

    def _process_titles(self, titles):
        """Look up every title on Semantic Scholar, preserving order."""
        return [self._query_semantic_scholar(title=title) for title in titles]

    def _clean_text(self, text):
        """Normalize a title for comparison.

        Removes ASCII punctuation, collapses runs of whitespace, and lowercases.
        Returns '' for non-string input, including None.
        """
        if not isinstance(text, str):
            return ""
        without_punctuation = text.translate(self._PUNCTUATION_TABLE)
        return " ".join(without_punctuation.split()).lower()

    def _evaluate_response(self, prompt, response):
        """Build the per-response record that ``evaluate_references`` writes out."""
        if self._is_response_abstained(prompt, response):
            generated_response = ""
            atomic_units = []
            s2_titles = []
            hallucinated_atomic_units = []
        else:
            generated_response = self._generate_response(response)
            atomic_units = self._extract_atomic_units(generated_response)
            s2_titles = self._process_titles(atomic_units)
            # A title counts as hallucinated when the top S2 hit differs after normalization.
            hallucinated_atomic_units = [
                unit for unit, found in zip(atomic_units, s2_titles)
                if self._clean_text(unit) != self._clean_text(found)
            ]

        return {
            "response": response,
            "parsed_response": generated_response,
            "atomic_units": atomic_units,
            "s2_titles": s2_titles,
            "hallucinated_atomic_units": hallucinated_atomic_units,
        }

    def evaluate_references(self, filename, json_input, output_directory):
        """Evaluate every response in *json_input* and write the results as JSON.

        Args:
            filename: Name of the output file (also used in the progress message).
            json_input: Path to a JSON list of objects. Each object has a
                ``'Prompt'`` key and a ``'Responses'`` key (a list of strings;
                a single string is also accepted).
            output_directory: Directory for the output file. It is created if
                missing; '' means the current directory.

        Returns:
            The path of the written output file. The file holds a list of
            ``{prompt: record}`` dicts, one per response.
        """
        print(f"Processing {filename}...")

        with open(json_input, "r", encoding="utf-8") as f:
            items = json.load(f)

        output_list = []
        for obj in tqdm(items):
            prompt = obj['Prompt']
            responses = obj['Responses']
            # A bare string would otherwise be iterated character by character.
            if isinstance(responses, str):
                responses = [responses]
            for response in responses:
                output_list.append({prompt: self._evaluate_response(prompt, response)})

        if output_directory:
            os.makedirs(output_directory, exist_ok=True)
        output_file_path = os.path.join(output_directory, filename)
        with open(output_file_path, "w", encoding="utf-8") as f:
            json.dump(output_list, f, ensure_ascii=False, indent=2)
        return output_file_path